{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training a classifier using model2vec\n",
    "\n",
    "Model2Vec supports built-in classifier training with a simple, scikit-learn-style interface. Just pass your data to `.fit`, and you'll have a trained model!\n",
    "\n",
    "How it works:\n",
    "* We load a base `StaticModel` as a torch module. By default we use [potion-base-8M](https://huggingface.co/minishlab/potion-base-8M).\n",
    "* We add an MLP head with one hidden layer of 512 units and `ReLU` activation.\n",
    "* We train the model with a cross-entropy loss, using [`pytorch-lightning`](https://lightning.ai/docs/pytorch/stable/) as the training framework.\n",
    "\n",
    "After training, you can save the model using regular torch tools, such as `torch.save` and `torch.load`, or you can export it to a `scikit-learn` pipeline. The latter option gives a much smaller footprint during inference, because `torch` is no longer needed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[2mUsing Python 3.11.4 environment at: /Users/stephantulkens/Documents/GitHub/model2vec/.venv\u001b[0m\n",
      "\u001b[2mAudited \u001b[1m1 package\u001b[0m \u001b[2min 4ms\u001b[0m\u001b[0m\n",
      "\u001b[2mUsing Python 3.11.4 environment at: /Users/stephantulkens/Documents/GitHub/model2vec/.venv\u001b[0m\n",
      "\u001b[2mAudited \u001b[1m1 package\u001b[0m \u001b[2min 8ms\u001b[0m\u001b[0m\n",
      "\u001b[2mUsing Python 3.11.4 environment at: /Users/stephantulkens/Documents/GitHub/model2vec/.venv\u001b[0m\n",
      "\u001b[2mAudited \u001b[1m1 package\u001b[0m \u001b[2min 3ms\u001b[0m\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "# Install the necessary libraries\n",
    "!uv pip install \"model2vec[train,inference]\"\n",
    "!uv pip install \"datasets\"\n",
    "!uv pip install \"scikit-learn\"\n",
    "\n",
    "# Import the necessary libraries\n",
    "from model2vec.train import StaticModelForClassification\n",
    "from model2vec.inference import StaticModelPipeline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To demonstrate how to train a model, we'll use the `20_newsgroups` dataset, in which each post comes from one of 20 newsgroups."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Repo card metadata block was not found. Setting CardData to empty.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "DatasetDict({\n",
      "    train: Dataset({\n",
      "        features: ['text', 'label', 'label_text'],\n",
      "        num_rows: 11314\n",
      "    })\n",
      "    test: Dataset({\n",
      "        features: ['text', 'label', 'label_text'],\n",
      "        num_rows: 7532\n",
      "    })\n",
      "})\n"
     ]
    }
   ],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "dataset = load_dataset(\"setfit/20_newsgroups\")\n",
    "print(dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's take a look at the first five training samples:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TEXT: I was wondering if anyone out there could enlighten me on this car I saw\n",
      "the other day. It was a 2-door sports car, looked to be from the late 60s/\n",
      "early 70s. It was called a Bricklin. The doors were really small. In addition,\n",
      "the front bumper was separate from the rest of the body. This is \n",
      "all I know. If anyone can tellme a model name, engine specs, years\n",
      "of production, where this car is made, history, or whatever info you\n",
      "have on this funky looking car, please e-mail. LABEL: rec.autos\n",
      "TEXT: A fair number of brave souls who upgraded their SI clock oscillator have\n",
      "shared their experiences for this poll. Please send a brief message detailing\n",
      "your experiences with the procedure. Top speed attained, CPU rated speed,\n",
      "add on cards and adapters, heat sinks, hour of usage per day, floppy disk\n",
      "functionality with 800 and 1.4 m floppies are especially requested.\n",
      "\n",
      "I will be summarizing in the next two days, so please add to the network\n",
      "knowledge base if you have done the clock upgrade and haven't answered this\n",
      "poll. Thanks. LABEL: comp.sys.mac.hardware\n",
      "TEXT: well folks, my mac plus finally gave up the ghost this weekend after\n",
      "starting life as a 512k way back in 1985.  sooo, i'm in the market for a\n",
      "new machine a bit sooner than i intended to be...\n",
      "\n",
      "i'm looking into picking up a powerbook 160 or maybe 180 and have a bunch\n",
      "of questions that (hopefully) somebody can answer:\n",
      "\n",
      "* does anybody know any dirt on when the next round of powerbook\n",
      "introductions are expected?  i'd heard the 185c was supposed to make an\n",
      "appearence \"this summer\" but haven't heard anymore on it - and since i\n",
      "don't have access to macleak, i was wondering if anybody out there had\n",
      "more info...\n",
      "\n",
      "* has anybody heard rumors about price drops to the powerbook line like the\n",
      "ones the duo's just went through recently?\n",
      "\n",
      "* what's the impression of the display on the 180?  i could probably swing\n",
      "a 180 if i got the 80Mb disk rather than the 120, but i don't really have\n",
      "a feel for how much \"better\" the display is (yea, it looks great in the\n",
      "store, but is that all \"wow\" or is it really that good?).  could i solicit\n",
      "some opinions of people who use the 160 and 180 day-to-day on if its worth\n",
      "taking the disk size and money hit to get the active display?  (i realize\n",
      "this is a real subjective question, but i've only played around with the\n",
      "machines in a computer store breifly and figured the opinions of somebody\n",
      "who actually uses the machine daily might prove helpful).\n",
      "\n",
      "* how well does hellcats perform?  ;)\n",
      "\n",
      "thanks a bunch in advance for any info - if you could email, i'll post a\n",
      "summary (news reading time is at a premium with finals just around the\n",
      "corner... :( )\n",
      "--\n",
      "Tom Willis  \\  twillis@ecn.purdue.edu    \\    Purdue Electrical Engineering LABEL: comp.sys.mac.hardware\n",
      "TEXT: \n",
      "Do you have Weitek's address/phone number?  I'd like to get some information\n",
      "about this chip.\n",
      " LABEL: comp.graphics\n",
      "TEXT: From article <C5owCB.n3p@world.std.com>, by tombaker@world.std.com (Tom A Baker):\n",
      "\n",
      "\n",
      "My understanding is that the 'expected errors' are basically\n",
      "known bugs in the warning system software - things are checked\n",
      "that don't have the right values in yet because they aren't\n",
      "set till after launch, and suchlike. Rather than fix the code\n",
      "and possibly introduce new bugs, they just tell the crew\n",
      "'ok, if you see a warning no. 213 before liftoff, ignore it'. LABEL: sci.space\n"
     ]
    }
   ],
   "source": [
    "# First 5 training samples:\n",
    "for record in dataset[\"train\"].select(range(5)):\n",
    "    print(f\"TEXT: {record['text']} LABEL: {record['label_text']}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "StaticModelForClassification(\n",
      "  (embeddings): Embedding(29528, 256, padding_idx=0)\n",
      "  (head): Sequential(\n",
      "    (0): Linear(in_features=256, out_features=512, bias=True)\n",
      "    (1): ReLU()\n",
      "    (2): Linear(in_features=512, out_features=2, bias=True)\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# Define the static model\n",
    "model = StaticModelForClassification.from_pretrained()\n",
    "# Optional arguments:\n",
    "# model_name: the name of the base model (defaults to potion-base-8M)\n",
    "# n_layers: the number of layers in the MLP (defaults to 1)\n",
    "# hidden_dim: the number of hidden units (defaults to 512)\n",
    "print(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's train the model on a subset of the data: the first 1000 training examples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Seed set to 42\n",
      "GPU available: True (mps), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n",
      "/Users/stephantulkens/Documents/GitHub/model2vec/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default\n",
      "/Users/stephantulkens/Documents/GitHub/model2vec/.venv/lib/python3.11/site-packages/torch/optim/lr_scheduler.py:60: UserWarning: The verbose parameter is deprecated. Please use get_last_lr() to access the learning rate.\n",
      "  warnings.warn(\n",
      "\n",
      "  | Name  | Type                         | Params | Mode \n",
      "---------------------------------------------------------------\n",
      "0 | model | StaticModelForClassification | 7.7 M  | train\n",
      "---------------------------------------------------------------\n",
      "7.7 M     Trainable params\n",
      "0         Non-trainable params\n",
      "7.7 M     Total params\n",
      "30.922    Total estimated model params size (MB)\n",
      "6         Modules in train mode\n",
      "0         Modules in eval mode\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2351ba8c0b53458fb680e8d29e0f0a6c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Sanity Checking: |                                                                             | 0/? [00:00<?,…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/stephantulkens/Documents/GitHub/model2vec/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.\n",
      "/Users/stephantulkens/Documents/GitHub/model2vec/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.\n",
      "/Users/stephantulkens/Documents/GitHub/model2vec/.venv/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py:310: The number of training batches (29) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "33583b55b2194b72b61bbf7c927821fb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training: |                                                                                    | 0/? [00:00<?,…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "06d6a6682ac749169810a1b62db74c4b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |                                                                                  | 0/? [00:00<?,…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "40e9df4d364a4f759babfadb33d3298f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |                                                                                  | 0/? [00:00<?,…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6e1a9dbaf3b14a4c8d40ab58b512b431",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |                                                                                  | 0/? [00:00<?,…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "76ec75ab437d484d9cfac571c3e48b36",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |                                                                                  | 0/? [00:00<?,…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "23296b05b0844c51956dffe0f2f37ded",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |                                                                                  | 0/? [00:00<?,…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1eb5d063142347709f6914206c9b99cc",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |                                                                                  | 0/? [00:00<?,…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "44ad933a5e1e444ca08dd947a506f90b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |                                                                                  | 0/? [00:00<?,…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f8a09730417a418897c606cf28e1a550",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |                                                                                  | 0/? [00:00<?,…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2a3323088d934f1fad4582ca0002d227",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |                                                                                  | 0/? [00:00<?,…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e6c64447ca004416875c0a8a9f45840f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |                                                                                  | 0/? [00:00<?,…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "dd3285381a9b447d9aaf4dcc28da95e6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |                                                                                  | 0/? [00:00<?,…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "56c20fe106b8476793dbd4f8b77f008d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |                                                                                  | 0/? [00:00<?,…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4e19b7d693b042adbf4f1a906b4076e9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |                                                                                  | 0/? [00:00<?,…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6d3c8dd7f34c423b8b76be8339044149",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |                                                                                  | 0/? [00:00<?,…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0b70ca545c8549b289400e3dc79953d4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |                                                                                  | 0/? [00:00<?,…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "58f0bc837ade4816a30c797242ef295c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |                                                                                  | 0/? [00:00<?,…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c797cdacb0a7437a929388283ce0fe98",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |                                                                                  | 0/? [00:00<?,…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "24c02c9316c847ae9f377e36185c3b3e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |                                                                                  | 0/? [00:00<?,…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5000bc2f140447cfa39581ddf2a262ce",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |                                                                                  | 0/? [00:00<?,…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "training took 8.715388059616089 seconds\n"
     ]
    }
   ],
   "source": [
    "import time\n",
    "\n",
    "# Fit the model on the first 1000 records\n",
    "subset = dataset[\"train\"].select(range(1000))\n",
    "s = time.time()\n",
    "model = model.fit(subset[\"text\"], subset[\"label_text\"])\n",
    "print(f\"training took {time.time() - s} seconds\")\n",
    "# `fit` accepts many more arguments; check them out!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We have trained a classifier in under 9 seconds. Nice!\n",
    "\n",
    "Let's take a look at how good it is."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                          precision    recall  f1-score   support\n",
      "\n",
      "             alt.atheism       0.27      0.38      0.31       319\n",
      "           comp.graphics       0.62      0.55      0.58       389\n",
      " comp.os.ms-windows.misc       0.47      0.50      0.48       394\n",
      "comp.sys.ibm.pc.hardware       0.50      0.49      0.49       392\n",
      "   comp.sys.mac.hardware       0.47      0.47      0.47       385\n",
      "          comp.windows.x       0.75      0.57      0.65       395\n",
      "            misc.forsale       0.69      0.75      0.72       390\n",
      "               rec.autos       0.46      0.67      0.54       396\n",
      "         rec.motorcycles       0.69      0.56      0.61       398\n",
      "      rec.sport.baseball       0.73      0.72      0.72       397\n",
      "        rec.sport.hockey       0.82      0.76      0.79       399\n",
      "               sci.crypt       0.60      0.62      0.61       396\n",
      "         sci.electronics       0.42      0.47      0.44       393\n",
      "                 sci.med       0.68      0.75      0.71       396\n",
      "               sci.space       0.61      0.64      0.63       394\n",
      "  soc.religion.christian       0.56      0.72      0.63       398\n",
      "      talk.politics.guns       0.54      0.44      0.49       364\n",
      "   talk.politics.mideast       0.79      0.69      0.74       376\n",
      "      talk.politics.misc       0.39      0.28      0.33       310\n",
      "      talk.religion.misc       0.27      0.12      0.16       251\n",
      "\n",
      "                accuracy                           0.57      7532\n",
      "               macro avg       0.57      0.56      0.56      7532\n",
      "            weighted avg       0.58      0.57      0.57      7532\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from sklearn.metrics import classification_report\n",
    "\n",
    "predictions = model.predict(dataset[\"test\"][\"text\"])\n",
    "\n",
    "print(classification_report(dataset[\"test\"][\"label_text\"], predictions))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Our model reaches 0.57 accuracy on the test set. To put that number in context, let's compare it to a `tf-idf` plus logistic regression pipeline from `scikit-learn`, trained on the same 1000 examples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                          precision    recall  f1-score   support\n",
      "\n",
      "             alt.atheism       0.21      0.17      0.19       319\n",
      "           comp.graphics       0.73      0.30      0.43       389\n",
      " comp.os.ms-windows.misc       0.44      0.57      0.50       394\n",
      "comp.sys.ibm.pc.hardware       0.67      0.18      0.28       392\n",
      "   comp.sys.mac.hardware       0.29      0.67      0.41       385\n",
      "          comp.windows.x       0.83      0.41      0.55       395\n",
      "            misc.forsale       0.49      0.79      0.61       390\n",
      "               rec.autos       0.61      0.57      0.59       396\n",
      "         rec.motorcycles       0.80      0.40      0.53       398\n",
      "      rec.sport.baseball       0.29      0.64      0.40       397\n",
      "        rec.sport.hockey       0.77      0.61      0.68       399\n",
      "               sci.crypt       0.71      0.48      0.57       396\n",
      "         sci.electronics       0.32      0.31      0.31       393\n",
      "                 sci.med       0.66      0.31      0.42       396\n",
      "               sci.space       0.82      0.41      0.54       394\n",
      "  soc.religion.christian       0.22      0.84      0.35       398\n",
      "      talk.politics.guns       0.58      0.10      0.18       364\n",
      "   talk.politics.mideast       0.50      0.59      0.54       376\n",
      "      talk.politics.misc       1.00      0.01      0.02       310\n",
      "      talk.religion.misc       1.00      0.00      0.01       251\n",
      "\n",
      "                accuracy                           0.43      7532\n",
      "               macro avg       0.60      0.42      0.41      7532\n",
      "            weighted avg       0.59      0.43      0.42      7532\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.feature_extraction.text import TfidfVectorizer\n",
    "from sklearn.pipeline import make_pipeline\n",
    "\n",
    "sklearn_pipeline = make_pipeline(TfidfVectorizer(), LogisticRegression())\n",
    "sklearn_pipeline.fit(subset[\"text\"], subset[\"label_text\"])\n",
    "predictions = sklearn_pipeline.predict(dataset[\"test\"][\"text\"])\n",
    "\n",
    "print(classification_report(dataset[\"test\"][\"label_text\"], predictions))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Pretty good! We outperform the tf-idf pipeline by a wide margin: 0.57 versus 0.43 accuracy.\n",
    "\n",
    "We can now export the model to a scikit-learn pipeline and push it to the hub. But first, let's verify that the exported pipeline scores the same as the original model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                          precision    recall  f1-score   support\n",
      "\n",
      "             alt.atheism       0.27      0.38      0.31       319\n",
      "           comp.graphics       0.62      0.55      0.58       389\n",
      " comp.os.ms-windows.misc       0.47      0.50      0.48       394\n",
      "comp.sys.ibm.pc.hardware       0.50      0.49      0.49       392\n",
      "   comp.sys.mac.hardware       0.47      0.47      0.47       385\n",
      "          comp.windows.x       0.75      0.57      0.65       395\n",
      "            misc.forsale       0.69      0.75      0.72       390\n",
      "               rec.autos       0.46      0.67      0.54       396\n",
      "         rec.motorcycles       0.69      0.56      0.61       398\n",
      "      rec.sport.baseball       0.73      0.72      0.72       397\n",
      "        rec.sport.hockey       0.82      0.76      0.79       399\n",
      "               sci.crypt       0.60      0.62      0.61       396\n",
      "         sci.electronics       0.42      0.47      0.44       393\n",
      "                 sci.med       0.68      0.75      0.71       396\n",
      "               sci.space       0.61      0.64      0.63       394\n",
      "  soc.religion.christian       0.56      0.72      0.63       398\n",
      "      talk.politics.guns       0.54      0.44      0.49       364\n",
      "   talk.politics.mideast       0.79      0.69      0.74       376\n",
      "      talk.politics.misc       0.39      0.28      0.33       310\n",
      "      talk.religion.misc       0.27      0.12      0.16       251\n",
      "\n",
      "                accuracy                           0.57      7532\n",
      "               macro avg       0.57      0.56      0.56      7532\n",
      "            weighted avg       0.58      0.57      0.57      7532\n",
      "\n"
     ]
    }
   ],
   "source": [
    "pipeline = model.to_pipeline()\n",
    "\n",
    "predictions = pipeline.predict(dataset[\"test\"][\"text\"])\n",
    "\n",
    "print(classification_report(dataset[\"test\"][\"label_text\"], predictions))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two reports are identical, so let's save the pipeline locally, or push it to the hub!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "pipeline.save_pretrained(\"my_cool_model\")\n",
    "# Fill in your own org\n",
    "# pipeline.push_to_hub(\"my_org/my_model\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This saves a model to a local folder. The model can then be loaded as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "new_model = StaticModelPipeline.from_pretrained(\"my_cool_model\")\n",
    "# Or from the hub\n",
    "# new_model = StaticModelPipeline.from_pretrained(\"my_org/my_model\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One reason to work this way is that `StaticModelPipeline` does not require torch to be installed at all, which leads to fast cold-start predictions, smaller images, and a lot less hassle overall.\n",
    "\n",
    "And that's it! Super fast, super small, super good classifiers."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
